import logging
import socket
from argparse import Namespace
from collections.abc import Mapping, Sequence

import ray
import torch
import torch.distributed as dist
from ray.actor import ActorHandle
from torch.distributed.tensor import DTensor

try:
    from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions  # type: ignore[import]
except ImportError:
    from sglang.srt.patch_torch import monkey_patch_torch_reductions  # type: ignore[import]

from sglang.srt.utils import MultiprocessingSerializer
from tqdm import tqdm

from slime.utils.distributed_utils import init_process_group
from slime.utils.memory_utils import clear_memory
from slime.utils.types import ParamInfo

try:
    try:
        from sglang.srt.weight_sync.tensor_bucket import FlattenedTensorBucket  # type: ignore[import]
    except ImportError:
        from sglang.srt.model_executor.model_runner import FlattenedTensorBucket  # type: ignore[import]

    use_flattened_tensor_bucket = True
except ImportError:
    use_flattened_tensor_bucket = False


logger = logging.getLogger(__name__)


def get_param_info_buckets(
    args: Namespace, weights: Mapping[str, Mapping[str, torch.Tensor]]
) -> list[list[ParamInfo]]:
    """Build `ParamInfo` buckets capped by `args.update_weight_buffer_size`.

    - Expects `weights["actor"]` to map parameter names to CPU tensors.
    - Each `ParamInfo.size` is computed in bytes (`numel * element_size`).
    - Buckets are sorted by parameter name for determinism.
    - Returns a list of buckets, each a list of `ParamInfo`.
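
    Example (illustrative; assumes `args.update_weight_buffer_size` is a byte
    count and `weights["actor"]` maps parameter names to CPU tensors):

        buckets = get_param_info_buckets(args, weights)
        for bucket in buckets:
            # A bucket stays under the limit unless a single parameter already
            # exceeds it, in which case that parameter gets its own bucket.
            total = sum(info.size for info in bucket)
            assert total <= args.update_weight_buffer_size or len(bucket) == 1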
    """
    # Create ParamInfo objects for each parameter
    param_infos = []
    rank = dist.get_rank()

    for name, param in weights["actor"].items():
        param_infos.append(
            ParamInfo(
                name=name,
                dtype=param.dtype,
                shape=param.shape,
                attrs={},  # FSDP doesn't need complex tensor parallel attrs
                size=param.numel() * param.element_size(),
                src_rank=rank,  # All parameters available on all ranks for FSDP
            )
        )

    # Sort by name for consistency
    param_infos = sorted(param_infos, key=lambda info: info.name)

    # Create buckets based on buffer size (similar to Megatron)
    param_info_buckets: list[list[ParamInfo]] = [[]]
    buffer_size = 0
    buffer_size_limit = args.update_weight_buffer_size

    for info in param_infos:
        param_size = info.size

        if buffer_size + param_size > buffer_size_limit and len(param_info_buckets[-1]) > 0:
            param_info_buckets.append([])
            buffer_size = 0
        param_info_buckets[-1].append(info)
        buffer_size += param_size

    return param_info_buckets


class UpdateWeightFromTensor:
    """Push model weights to rollout engines using tensors.

    Streams parameters in size-bounded buckets; optionally groups tensors by dtype
    and flattens per dtype, gathers per-rank blobs to the source, and issues one
    RPC per dtype per bucket (or one per bucket if not flattened).
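
    Typical call sequence (illustrative):

        updater = UpdateWeightFromTensor(args, model, weights)
        updater.connect_rollout_engines(rollout_engines, rollout_engine_lock=None)
        updater.update_weights()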
    """

    def __init__(
        self,
        args: Namespace,
        model: torch.nn.Module,
        weights: Mapping[str, Mapping[str, torch.Tensor]] | None,
    ) -> None:
        self.args = args
        self.model = model
        self.weights = weights  # CPU parameter storage

        # Validate weights initialization
        if self.weights is None:
            raise RuntimeError("weights cannot be None - CPU parameter storage is required for weight updates")

        # Create parameter info buckets once during initialization (like Megatron)
        self.param_info_buckets = get_param_info_buckets(self.args, self.weights)

        # FSDP v2 model expected

        # Set up tensor parallel configuration for SGLang
        self.tp_size = args.rollout_num_gpus_per_engine
        # tp_rank will be set during connect_rollout_engines based on the IPC group

    def connect_rollout_engines(
        self,
        rollout_engines: Sequence[ActorHandle],
        rollout_engine_lock: ActorHandle | None,
    ) -> None:
        """Attach rollout engines and create per-engine IPC (Gloo) groups.

        Sets the gather source rank, engine handle, and `tp_rank` within the
        engine's local group.
        """
        self.rollout_engines = rollout_engines

        # Here we assume the rollout engines and the training actors use the same GPU IDs.
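        # Note: `dist.new_group` is a collective call, so every rank must create
        # every group, even though only the ranks inside `group_ranks` keep a
        # handle to theirs.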
        for i, engine in enumerate(self.rollout_engines):
            start_rank = i * self.args.rollout_num_gpus_per_engine
            end_rank = (i + 1) * self.args.rollout_num_gpus_per_engine
            group_ranks = list(range(start_rank, end_rank))
            new_group = dist.new_group(
                ranks=group_ranks,
                backend="gloo",
            )
            if dist.get_rank() in group_ranks:
                self._ipc_gather_src = start_rank
                self._ipc_gather_group = new_group
                self._ipc_engine = engine
                # Calculate TP rank within this SGLang engine group
                self.tp_rank = dist.get_rank() - start_rank

    @torch.no_grad()
    def update_weights(self) -> None:
        """Send weights over IPC using bucket-based loading."""

        monkey_patch_torch_reductions()

        logger.info("Using bucket-based loading from CPU storage")
        if self.param_info_buckets is None:
            raise RuntimeError("Parameter info buckets not initialized")

        for param_infos in self.param_info_buckets:
            # Load only the parameters in this bucket from CPU to GPU
            named_tensors_batch = []
            for param_info in param_infos:
                cpu_param = self.weights["actor"][param_info.name]
                gpu_param = cpu_param.to(device=torch.cuda.current_device(), non_blocking=True)
                named_tensors_batch.append((param_info.name, gpu_param))

            torch.cuda.synchronize()
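            # The synchronize above ensures the non-blocking host-to-device copies
            # have finished before the tensors are serialized for the engine.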

            # Use flattened bucket approach similar to Megatron
            if use_flattened_tensor_bucket:
                logger.info("Using flattened tensor bucket")
                # Group tensors by dtype (same as Megatron)
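                # (the flattened buffer produced per group can only hold a single dtype)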
                named_tensors_by_dtypes = {}
                for name, tensor in named_tensors_batch:
                    dtype = tensor.dtype
                    if dtype not in named_tensors_by_dtypes:
                        named_tensors_by_dtypes[dtype] = []
                    named_tensors_by_dtypes[dtype].append((name, tensor))

                # Create flattened bucket for each dtype group
                serialized_tensors = []
                for dtype, named_tensors in named_tensors_by_dtypes.items():
                    flattened_tensor_bucket = FlattenedTensorBucket(named_tensors=named_tensors)
                    metadata = flattened_tensor_bucket.get_metadata()
                    flattened_tensor_data = {
                        "flattened_tensor": flattened_tensor_bucket.get_flattened_tensor(),
                        "metadata": metadata,
                    }
                    serialized_tensors.append(
                        MultiprocessingSerializer.serialize(flattened_tensor_data, output_str=True)
                    )
            else:
                # Fallback to non-flattened approach
                serialized_tensors = MultiprocessingSerializer.serialize(named_tensors_batch, output_str=True)

            del named_tensors_batch
            clear_memory()

            if self._ipc_gather_src == dist.get_rank():
                # On the gather source rank, prepare a list to hold the batches
                # gathered from every rank in the engine group.
                gathered_serialized_batches = [None for _ in range(dist.get_world_size(self._ipc_gather_group))]
            else:
                gathered_serialized_batches = None

            # Gather the serialized batches from all ranks in the engine group to the gather source rank.
            dist.gather_object(
                obj=serialized_tensors,
                object_gather_list=gathered_serialized_batches,
                dst=self._ipc_gather_src,
                group=self._ipc_gather_group,
            )
            del serialized_tensors
            clear_memory()

            if dist.get_rank() == self._ipc_gather_src:
                if use_flattened_tensor_bucket:
                    # Handle flattened bucket format (same as Megatron approach)
                    # Each rank may have multiple dtype buckets
                    # TODO: here we assume all ranks have the same number of dtypes
                    num_dtypes = len(gathered_serialized_batches[0])
                    for i in range(num_dtypes):
                        kwargs = {
                            "serialized_named_tensors": [tensors[i] for tensors in gathered_serialized_batches],
                            "load_format": "flattened_bucket",
                            "flush_cache": False,
                        }
                        ref = self._ipc_engine.update_weights_from_tensor.remote(**kwargs)
                        ray.get(ref)
                else:
                    # Non-flattened approach
                    kwargs = {
                        "serialized_named_tensors": gathered_serialized_batches,
                        "flush_cache": False,
                    }
                    ref = self._ipc_engine.update_weights_from_tensor.remote(**kwargs)
                    ray.get(ref)

                del gathered_serialized_batches, kwargs
                clear_memory()

        if dist.get_rank() == self._ipc_gather_src:
            ref = self._ipc_engine.flush_cache.remote()
            ray.get(ref)
            clear_memory()


class UpdateWeightFromDistributed:
    """Broadcast weights via a temporary NCCL group to rollout engines."""

    def __init__(self, args: Namespace, model: torch.nn.Module, weights) -> None:
        self.args = args
        self.model = model
        self.weights = weights  # CPU parameter storage

        # Validate weights initialization
        if self.weights is None:
            raise RuntimeError("weights cannot be None - CPU parameter storage is required for weight updates")

        # Create parameter info buckets for bucket-based loading
        self.param_info_buckets = get_param_info_buckets(self.args, self.weights)

    def connect_rollout_engines(
        self,
        rollout_engines: Sequence[ActorHandle],
        rollout_engine_lock: ActorHandle | None,
    ) -> None:
        """On rank 0, initialize a temporary NCCL group for parameter broadcast."""
        self.rollout_engines = rollout_engines
        self.rollout_engine_lock = rollout_engine_lock

        # For TP:
        #   1. All-gather parameters to rank 0
        #   2. Broadcast parameters from rank 0 to all SGLang engines
        self._is_src_rank = dist.get_rank() == 0
        if self._is_src_rank:
            self._group_name = "slime"
            master_address = ray._private.services.get_node_ip_address()
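            # Bind to port 0 so the OS picks a free ephemeral port for the
            # NCCL rendezvous.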
            with socket.socket() as sock:
                sock.bind(("", 0))
                master_port = sock.getsockname()[1]
            # +1 because this training process joins the group as rank 0;
            # the rollout engine GPUs occupy ranks 1..rollout_num_gpus.
            world_size = self.args.rollout_num_gpus + 1

            refs = [
                engine.init_weights_update_group.remote(
                    master_address,
                    master_port,
                    i * self.args.rollout_num_gpus_per_engine + 1,
                    world_size,
                    self._group_name,
                    backend="nccl",
                )
                for i, engine in enumerate(self.rollout_engines)
            ]
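            # The engines' `init_weights_update_group` calls above run
            # concurrently; this process then joins the same group as rank 0,
            # and `ray.get(refs)` below waits for every engine to finish its
            # side of the initialization.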
            self._model_update_groups = init_process_group(
                backend="nccl",
                init_method=f"tcp://{master_address}:{master_port}",
                world_size=world_size,
                rank=0,
                group_name=self._group_name,
            )
            ray.get(refs)

    @torch.no_grad()
    def update_weights(self) -> None:
        """Broadcast weights in buckets to minimize memory usage and reduce communication overhead."""
        model = self.model
        torch.cuda.empty_cache()
        clear_memory()

        if self.param_info_buckets is None:
            raise RuntimeError("Parameter info buckets not initialized for distributed mode")

        for param_infos in tqdm(
            self.param_info_buckets, desc="[broadcast weight buckets]", disable=not self._is_src_rank
        ):
            # Load all parameters in this bucket from CPU to GPU
            bucket_state_dict = {}
            for param_info in param_infos:
                cpu_param = self.weights["actor"][param_info.name]
                gpu_param = cpu_param.to(device=torch.cuda.current_device(), dtype=torch.bfloat16, non_blocking=True)
                bucket_state_dict[param_info.name] = gpu_param

            torch.cuda.synchronize()

            self.request_update_params(bucket_state_dict)

            # Clean up bucket
            del bucket_state_dict
            clear_memory()

        dist.barrier()
        torch.cuda.empty_cache()
        return

    def request_update_params(self, state_dict: Mapping[str, torch.Tensor]) -> None:
        """Send names/dtypes/shapes metadata to engines, then broadcast tensors.

        Ensures tensors are contiguous; when `world_size == 1`, converts DTensors
        to full tensors prior to `dist.broadcast`.
        """
        if not self._is_src_rank or not state_dict:
            return

        refs = [
            engine.update_weights_from_distributed.remote(
                names=list(state_dict.keys()),
                dtypes=[param.dtype for param in state_dict.values()],
                shapes=[param.shape for param in state_dict.values()],
                group_name=self._group_name,
            )
            for engine in self.rollout_engines
        ]
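
        # The engine-side `update_weights_from_distributed` calls above run
        # concurrently; each `dist.broadcast` below is paired with the
        # corresponding receive on the engines, and `ray.get(refs)` at the end
        # waits for the engines to finish loading the weights.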

        # Broadcast parameters one by one with memory management
        for name, param in state_dict.items():
            torch.cuda.empty_cache()
            # Ensure the tensor is contiguous before broadcasting
            param_data = param.data.contiguous()

            # Avoid the `assert compute_mesh is not None` failure raised from
            # `DTensor._op_dispatcher.dispatch` when the world size is 1.
            if dist.get_world_size() == 1 and isinstance(param_data, DTensor):
                param_data = param_data.full_tensor()

            # Synchronous broadcast to avoid memory buildup
            dist.broadcast(param_data, 0, group=self._model_update_groups, async_op=False)

            # Clean up immediately after broadcast
            del param_data

        ray.get(refs)
